# Mount Google Drive so the dataset archive is reachable under /content/drive.
# This prompts for authorization on first run in a Colab session.
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
import os
import zipfile
import torch
import torchvision
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader
# Location of the dataset archive on the mounted Drive (adjust as needed).
zip_file = '/content/drive/MyDrive/newarchive.zip'
extract_dir = '/content'

# Unpack the archive into the working directory.
with zipfile.ZipFile(zip_file, 'r') as archive:
    archive.extractall(extract_dir)

# Root of the extracted dataset (adjust as needed).
data_dir = '/content/garbage_classification'

# Resize every image to 256x256 and convert it to a float tensor in [0, 1].
transformations = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
# Load the dataset: ImageFolder expects one subfolder per class under data_dir.
dataset = ImageFolder(data_dir, transform=transformations)

# Split sizes: 70% train / 10% validation / remainder test.
num_train = int(0.7 * len(dataset))
num_val = int(0.1 * len(dataset))
num_test = len(dataset) - num_train - num_val

# Seeded generator makes the train/val/test split reproducible across runs.
train_ds, val_ds, test_ds = random_split(
    dataset, [num_train, num_val, num_test],
    generator=torch.Generator().manual_seed(42))

# DataLoaders. Cap workers at the machine's CPU count: the Colab runtime
# only has 2 cores, and requesting 4 triggered the DataLoader warning about
# excessive worker creation (slowdowns / potential freezes).
batch_size = 32
num_workers = min(4, os.cpu_count() or 1)
train_dl = DataLoader(train_ds, batch_size, shuffle=True,
                      num_workers=num_workers, pin_memory=True)
val_dl = DataLoader(val_ds, batch_size*2, num_workers=num_workers, pin_memory=True)
test_dl = DataLoader(test_ds, batch_size*2, num_workers=num_workers, pin_memory=True)

# Report the number of images in each split.
print(f'Tamaño del conjunto de entrenamiento: {len(train_ds)}')
print(f'Tamaño del conjunto de validación: {len(val_ds)}')
print(f'Tamaño del conjunto de prueba: {len(test_ds)}')
Tamaño del conjunto de entrenamiento: 10860 Tamaño del conjunto de validación: 1551 Tamaño del conjunto de prueba: 3104
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. warnings.warn(_create_warning_msg(
#### Step 2: Define the Model
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
def accuracy(outputs, labels):
    """Return the fraction of correct predictions as a 0-dim tensor.

    outputs: (batch, num_classes) logits; labels: (batch,) class indices.
    """
    _, predicted = torch.max(outputs, dim=1)
    correct = torch.sum(predicted == labels).item()
    return torch.tensor(correct / len(predicted))
class ImageClassificationBase(nn.Module):
    """Mixin with the generic train/validation step logic for classifiers."""

    def training_step(self, batch):
        """Return the cross-entropy loss for one training batch."""
        inputs, targets = batch
        logits = self(inputs)
        return F.cross_entropy(logits, targets)

    def validation_step(self, batch):
        """Compute detached loss and accuracy for one validation batch."""
        inputs, targets = batch
        logits = self(inputs)
        return {
            'val_loss': F.cross_entropy(logits, targets).detach(),
            'val_acc': accuracy(logits, targets),
        }

    def validation_epoch_end(self, outputs):
        """Average per-batch validation metrics into epoch-level floats."""
        mean_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
        mean_acc = torch.stack([o['val_acc'] for o in outputs]).mean()
        return {'val_loss': mean_loss.item(), 'val_acc': mean_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line summary of the epoch's metrics."""
        print(f"Epoch {epoch+1}: train_loss: {result['train_loss']:.4f}, val_loss: {result['val_loss']:.4f}, val_acc: {result['val_acc']:.4f}")
class ResNet(ImageClassificationBase):
    """ResNet-50 backbone with its head resized to the dataset's classes."""

    def __init__(self):
        super().__init__()
        # 'pretrained=True' is deprecated since torchvision 0.13; the weights
        # enum is the supported spelling of the same ImageNet checkpoint.
        self.network = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Swap the 1000-way ImageNet head for one sized to our class count.
        num_ftrs = self.network.fc.in_features
        self.network.fc = nn.Linear(num_ftrs, len(dataset.classes))

    def forward(self, xb):
        # Delegate straight to the wrapped torchvision model.
        return self.network(xb)
# Move the model to the GPU when one is available.
def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(name)
def to_device(data, device):
    """Move a tensor — or a list/tuple of (nested) tensors — to `device`.

    Note: tuples come back as lists, matching the recursive list build.
    """
    if isinstance(data, (list, tuple)):
        return [to_device(item, device) for item in data]
    return data.to(device, non_blocking=True)
class DeviceDataLoader():
    """Wrap a DataLoader so every yielded batch lands on a target device."""

    def __init__(self, dl, device):
        # Keep references to the underlying loader and destination device.
        self.dl = dl
        self.device = device

    def __iter__(self):
        """Yield each batch after transferring it to the target device."""
        for batch in self.dl:
            yield to_device(batch, self.device)

    def __len__(self):
        """Number of batches in the wrapped loader."""
        return len(self.dl)
# Pick the compute device (GPU if available) and wrap every loader so
# batches are moved there automatically during iteration.
device = get_default_device()
train_dl = DeviceDataLoader(train_dl, device)
val_dl = DeviceDataLoader(val_dl, device)
test_dl = DeviceDataLoader(test_dl, device)
# Instantiate the classifier and move its parameters to the same device.
model = ResNet()
model = to_device(model, device)
@torch.no_grad()
def evaluate(model, val_loader):
    """Run `model` over `val_loader` and return aggregated validation metrics.

    Gradients are disabled and the model is switched to eval mode.
    """
    model.eval()
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)


def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
    """Train `model` and return the per-epoch metric history.

    epochs: number of passes over train_loader.
    lr: learning rate handed to the optimizer constructor.
    opt_func: optimizer constructor (defaults to plain SGD).

    Returns a list of dicts with 'train_loss', 'val_loss' and 'val_acc'.
    """
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):
        model.train()
        train_losses = []
        for batch in train_loader:
            optimizer.zero_grad()
            loss = model.training_step(batch)
            loss.backward()
            optimizer.step()
            # Detach before storing: keeping the raw loss tensor would retain
            # every batch's autograd graph for the whole epoch (memory leak).
            train_losses.append(loss.detach())
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        model.epoch_end(epoch, result)
        history.append(result)
    return history
# Hyperparameter configuration and training run.
num_epochs = 8
opt_func = torch.optim.Adam
lr = 5.5e-5
history = fit(num_epochs, lr, model, train_dl, val_dl, opt_func)
# Save the trained weights (state_dict only, not the full module).
model_path = "resnet50_garbage_classification.pth"
torch.save(model.state_dict(), model_path)
print(f"Modelo guardado en {model_path}")
# Graficar precisión y pérdida
import matplotlib.pyplot as plt
def plot_accuracies(history):
    """Plot validation accuracy per epoch from a fit() history list."""
    acc_values = [entry['val_acc'] for entry in history]
    plt.plot(acc_values, '-x')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.title('Accuracy vs. No. of epochs')

plot_accuracies(history)
def plot_losses(history):
    """Plot training and validation loss per epoch from a fit() history."""
    losses_train = [entry.get('train_loss') for entry in history]
    losses_val = [entry['val_loss'] for entry in history]
    plt.plot(losses_train, '-bx')
    plt.plot(losses_val, '-rx')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['Training', 'Validation'])
    plt.title('Loss vs. No. of epochs')

plot_losses(history)
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead. warnings.warn( /usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights. warnings.warn(msg) Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth 100%|██████████| 97.8M/97.8M [00:00<00:00, 134MB/s]
Epoch 1: train_loss: 0.3769, val_loss: 0.1374, val_acc: 0.9567 Epoch 2: train_loss: 0.0808, val_loss: 0.1102, val_acc: 0.9706 Epoch 3: train_loss: 0.0422, val_loss: 0.0980, val_acc: 0.9731 Epoch 4: train_loss: 0.0370, val_loss: 0.1108, val_acc: 0.9711 Epoch 5: train_loss: 0.0294, val_loss: 0.0858, val_acc: 0.9756 Epoch 6: train_loss: 0.0343, val_loss: 0.1051, val_acc: 0.9730 Epoch 7: train_loss: 0.0185, val_loss: 0.1216, val_acc: 0.9673 Epoch 8: train_loss: 0.0264, val_loss: 0.1119, val_acc: 0.9748 Modelo guardado en resnet50_garbage_classification.pth
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
def predict_image(img, model):
    """Classify a single transformed image tensor; return its class name.

    Adds a batch dimension, moves the tensor to the global `device`, and
    maps the argmax logit back through `dataset.classes`.
    """
    batch = to_device(img.unsqueeze(0), device)
    logits = model(batch)
    _, top_idx = torch.max(logits, dim=1)
    return dataset.classes[top_idx[0].item()]
def predict_images_from_each_folder(base_dir, folders, model, num_images=5):
    """Show and print the model's prediction for a few images per folder.

    base_dir: root directory containing one subfolder per class.
    folders: class-subfolder names to sample from.
    num_images: how many .jpg files to take from each folder.
    """
    for folder in folders:
        folder_path = os.path.join(base_dir, folder)
        # Adjust the glob pattern if the dataset uses other extensions.
        image_paths = list(Path(folder_path).glob('*.jpg'))[:num_images]
        for image_path in image_paths:
            # Force RGB: a grayscale or CMYK JPEG would otherwise produce a
            # tensor whose channel count doesn't match the 3-channel model.
            image = Image.open(image_path).convert('RGB')
            example_image = transformations(image)
            # Predict once and reuse the label for both the plot and the log
            # (the original ran the forward pass twice per image).
            label = predict_image(example_image, model)
            plt.imshow(example_image.permute(1, 2, 0))
            plt.title(f"Prediction: {label}")
            plt.show()
            print(f"The image {image_path} resembles", label + ".")
# Class subfolders to sample predictions from (must match dataset.classes).
folders = ['battery', 'biological', 'brown-glass', 'cardboard', 'clothes', 'green-glass', 'metal', 'paper', 'plastic', 'shoes', 'trash', 'white-glass']
# Root directory holding the class subfolders.
base_dir = data_dir
# Run predictions over a few images from each class folder.
predict_images_from_each_folder(base_dir, folders, model)
The image /content/garbage_classification/battery/battery30.jpg resembles battery.
The image /content/garbage_classification/battery/battery803.jpg resembles battery.
The image /content/garbage_classification/battery/battery270.jpg resembles battery.
The image /content/garbage_classification/battery/battery493.jpg resembles battery.
The image /content/garbage_classification/battery/battery755.jpg resembles battery.
The image /content/garbage_classification/biological/biological358.jpg resembles biological.
The image /content/garbage_classification/biological/biological8.jpg resembles biological.
The image /content/garbage_classification/biological/biological242.jpg resembles biological.
The image /content/garbage_classification/biological/biological106.jpg resembles biological.
The image /content/garbage_classification/biological/biological916.jpg resembles biological.
The image /content/garbage_classification/brown-glass/brown-glass567.jpg resembles brown-glass.
The image /content/garbage_classification/brown-glass/brown-glass369.jpg resembles brown-glass.
The image /content/garbage_classification/brown-glass/brown-glass513.jpg resembles brown-glass.
The image /content/garbage_classification/brown-glass/brown-glass242.jpg resembles brown-glass.
The image /content/garbage_classification/brown-glass/brown-glass510.jpg resembles brown-glass.
The image /content/garbage_classification/cardboard/cardboard732.jpg resembles cardboard.
The image /content/garbage_classification/cardboard/cardboard703.jpg resembles cardboard.
The image /content/garbage_classification/cardboard/cardboard542.jpg resembles cardboard.
The image /content/garbage_classification/cardboard/cardboard835.jpg resembles paper.
The image /content/garbage_classification/cardboard/cardboard618.jpg resembles cardboard.
The image /content/garbage_classification/clothes/clothes864.jpg resembles clothes.
The image /content/garbage_classification/clothes/clothes5003.jpg resembles clothes.
The image /content/garbage_classification/clothes/clothes1114.jpg resembles clothes.
The image /content/garbage_classification/clothes/clothes284.jpg resembles clothes.
The image /content/garbage_classification/clothes/clothes4360.jpg resembles clothes.
The image /content/garbage_classification/green-glass/green-glass598.jpg resembles green-glass.
The image /content/garbage_classification/green-glass/green-glass124.jpg resembles green-glass.
The image /content/garbage_classification/green-glass/green-glass247.jpg resembles green-glass.
The image /content/garbage_classification/green-glass/green-glass123.jpg resembles green-glass.
The image /content/garbage_classification/green-glass/green-glass45.jpg resembles green-glass.
The image /content/garbage_classification/metal/metal686.jpg resembles metal.
The image /content/garbage_classification/metal/metal343.jpg resembles metal.
The image /content/garbage_classification/metal/metal440.jpg resembles metal.
The image /content/garbage_classification/metal/metal445.jpg resembles metal.
The image /content/garbage_classification/metal/metal443.jpg resembles metal.
The image /content/garbage_classification/paper/paper200.jpg resembles paper.
The image /content/garbage_classification/paper/paper235.jpg resembles paper.
The image /content/garbage_classification/paper/paper185.jpg resembles paper.
The image /content/garbage_classification/paper/paper35.jpg resembles paper.
The image /content/garbage_classification/paper/paper267.jpg resembles paper.
The image /content/garbage_classification/plastic/plastic228.jpg resembles plastic.
The image /content/garbage_classification/plastic/plastic384.jpg resembles plastic.
The image /content/garbage_classification/plastic/plastic665.jpg resembles plastic.
The image /content/garbage_classification/plastic/plastic764.jpg resembles plastic.
The image /content/garbage_classification/plastic/plastic421.jpg resembles plastic.
The image /content/garbage_classification/shoes/shoes501.jpg resembles shoes.
The image /content/garbage_classification/shoes/shoes172.jpg resembles shoes.
The image /content/garbage_classification/shoes/shoes1092.jpg resembles shoes.
The image /content/garbage_classification/shoes/shoes1281.jpg resembles shoes.
The image /content/garbage_classification/shoes/shoes1038.jpg resembles shoes.
The image /content/garbage_classification/trash/trash626.jpg resembles trash.
The image /content/garbage_classification/trash/trash188.jpg resembles trash.
The image /content/garbage_classification/trash/trash22.jpg resembles trash.
The image /content/garbage_classification/trash/trash306.jpg resembles trash.
The image /content/garbage_classification/trash/trash529.jpg resembles trash.
The image /content/garbage_classification/white-glass/white-glass410.jpg resembles white-glass.
The image /content/garbage_classification/white-glass/white-glass249.jpg resembles white-glass.
The image /content/garbage_classification/white-glass/white-glass389.jpg resembles white-glass.
The image /content/garbage_classification/white-glass/white-glass66.jpg resembles white-glass.
The image /content/garbage_classification/white-glass/white-glass600.jpg resembles white-glass.
!pip uninstall torch torchvision torchaudio
!pip install torch torchvision torchaudio
Found existing installation: torch 2.3.1+cu121
Uninstalling torch-2.3.1+cu121:
Would remove:
/usr/local/bin/convert-caffe2-to-onnx
/usr/local/bin/convert-onnx-to-caffe2
/usr/local/bin/torchrun
/usr/local/lib/python3.10/dist-packages/functorch/*
/usr/local/lib/python3.10/dist-packages/torch-2.3.1+cu121.dist-info/*
/usr/local/lib/python3.10/dist-packages/torch/*
/usr/local/lib/python3.10/dist-packages/torchgen/*
Proceed (Y/n)? y
Y
Successfully uninstalled torch-2.3.1+cu121
Found existing installation: torchvision 0.18.1+cu121
Uninstalling torchvision-0.18.1+cu121:
Would remove:
/usr/local/lib/python3.10/dist-packages/torchvision-0.18.1+cu121.dist-info/*
/usr/local/lib/python3.10/dist-packages/torchvision.libs/libcudart.7ec1eba6.so.12
/usr/local/lib/python3.10/dist-packages/torchvision.libs/libjpeg.ceea7512.so.62
/usr/local/lib/python3.10/dist-packages/torchvision.libs/libnvjpeg.f00ca762.so.12
/usr/local/lib/python3.10/dist-packages/torchvision.libs/libpng16.7f72a3c5.so.16
/usr/local/lib/python3.10/dist-packages/torchvision.libs/libz.4e87b236.so.1
/usr/local/lib/python3.10/dist-packages/torchvision/*
Proceed (Y/n)? Successfully uninstalled torchvision-0.18.1+cu121
Found existing installation: torchaudio 2.3.1+cu121
Uninstalling torchaudio-2.3.1+cu121:
Would remove:
/usr/local/lib/python3.10/dist-packages/torchaudio-2.3.1+cu121.dist-info/*
/usr/local/lib/python3.10/dist-packages/torchaudio/*
/usr/local/lib/python3.10/dist-packages/torio/*
Proceed (Y/n)? y
Successfully uninstalled torchaudio-2.3.1+cu121
Collecting torch
Downloading torch-2.4.0-cp310-cp310-manylinux1_x86_64.whl.metadata (26 kB)
Collecting torchvision
Downloading torchvision-0.19.0-cp310-cp310-manylinux1_x86_64.whl.metadata (6.0 kB)
Collecting torchaudio
Downloading torchaudio-2.4.0-cp310-cp310-manylinux1_x86_64.whl.metadata (6.4 kB)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.15.4)
Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch)
Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch)
Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch)
Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-nccl-cu12==2.20.5 (from torch)
Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting nvidia-nvtx-cu12==12.1.105 (from torch)
Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.7 kB)
Collecting triton==3.0.0 (from torch)
Downloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.3 kB)
Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch)
Downloading nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.25.2)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (9.4.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)
Downloading torch-2.4.0-cp310-cp310-manylinux1_x86_64.whl (797.2 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 797.2/797.2 MB 2.4 MB/s eta 0:00:00
Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 2.9 MB/s eta 0:00:00
Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)
Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)
Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)
Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)
Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)
Downloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 209.4/209.4 MB 2.0 MB/s eta 0:00:00
Downloading torchvision-0.19.0-cp310-cp310-manylinux1_x86_64.whl (7.0 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.0/7.0 MB 108.3 MB/s eta 0:00:00
Downloading torchaudio-2.4.0-cp310-cp310-manylinux1_x86_64.whl (3.4 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.4/3.4 MB 95.8 MB/s eta 0:00:00
Downloading nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl (21.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.3/21.3 MB 95.1 MB/s eta 0:00:00
Installing collected packages: triton, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch, torchvision, torchaudio
Attempting uninstall: triton
Found existing installation: triton 2.3.1
Uninstalling triton-2.3.1:
Successfully uninstalled triton-2.3.1
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
fastai 2.7.15 requires torch<2.4,>=1.10, but you have torch 2.4.0 which is incompatible.
Successfully installed nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.5.82 nvidia-nvtx-cu12-12.1.105 torch-2.4.0 torchaudio-2.4.0 torchvision-0.19.0 triton-3.0.0
!pip install onnx
Requirement already satisfied: onnx in /usr/local/lib/python3.10/dist-packages (1.16.1) Requirement already satisfied: numpy>=1.20 in /usr/local/lib/python3.10/dist-packages (from onnx) (1.25.2) Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx) (3.20.3)
import torch
import torchvision.models as models
import collections
import torch.nn as nn
# Rebuild the architecture with the class count used at training time.
num_classes = 12  # The original model was trained on 12 classes
# 'pretrained=False' is deprecated since torchvision 0.13; weights=None is
# the supported spelling for a randomly-initialized network.
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Load the saved checkpoint. weights_only=True avoids arbitrary pickle code
# execution (the FutureWarning default); map_location keeps this working on
# CPU-only runtimes.
state_dict = torch.load('/content/resnet50_garbage_classification.pth',
                        map_location='cpu', weights_only=True)

# The checkpoint came from a wrapper whose submodule was named 'network';
# strip that prefix so keys match the bare torchvision ResNet.
new_state_dict = collections.OrderedDict(
    (key.removeprefix('network.'), value) for key, value in state_dict.items())

# Load the remapped weights and switch to inference mode.
model.load_state_dict(new_state_dict)
model.eval()

# Dummy input fixing the exported graph's input shape: (1, 3, 224, 224).
dummy_input = torch.randn(1, 3, 224, 224)

# Export the model to ONNX with stable tensor names.
torch.onnx.export(model, dummy_input, "resnet50_garbage_classification.onnx",
                  input_names=["input"], output_names=["output"])
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
warnings.warn(msg)
<ipython-input-4-8bb255fc64c8>:14: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load('/content/resnet50_garbage_classification.pth')
pip install onnx onnx-tf
Requirement already satisfied: onnx in /usr/local/lib/python3.10/dist-packages (1.16.1) Collecting onnx-tf Downloading onnx_tf-1.10.0-py3-none-any.whl.metadata (510 bytes) Requirement already satisfied: numpy>=1.20 in /usr/local/lib/python3.10/dist-packages (from onnx) (1.25.2) Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx) (3.20.3) Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from onnx-tf) (6.0.1) Collecting tensorflow-addons (from onnx-tf) Downloading tensorflow_addons-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.8 kB) Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from tensorflow-addons->onnx-tf) (24.1) Collecting typeguard<3.0.0,>=2.7 (from tensorflow-addons->onnx-tf) Downloading typeguard-2.13.3-py3-none-any.whl.metadata (3.6 kB) Downloading onnx_tf-1.10.0-py3-none-any.whl (226 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 226.1/226.1 kB 18.5 MB/s eta 0:00:00 Downloading tensorflow_addons-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (611 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 611.8/611.8 kB 43.7 MB/s eta 0:00:00 Downloading typeguard-2.13.3-py3-none-any.whl (17 kB) Installing collected packages: typeguard, tensorflow-addons, onnx-tf Attempting uninstall: typeguard Found existing installation: typeguard 4.3.0 Uninstalling typeguard-4.3.0: Successfully uninstalled typeguard-4.3.0 ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. inflect 7.3.1 requires typeguard>=4.0.1, but you have typeguard 2.13.3 which is incompatible. Successfully installed onnx-tf-1.10.0 tensorflow-addons-0.23.0 typeguard-2.13.3
import onnx
from onnx_tf.backend import prepare
# Load the ONNX graph exported in the previous step.
onnx_model = onnx.load("/content/resnet50_garbage_classification.onnx")
# Build a TensorFlow representation of the ONNX graph.
tf_rep = prepare(onnx_model)
# Write the converted model out as a TensorFlow SavedModel directory.
tf_rep.export_graph("resnet50_garbage_classification")
/usr/local/lib/python3.10/dist-packages/tensorflow_addons/utils/tfa_eol_msg.py:23: UserWarning: TensorFlow Addons (TFA) has ended development and introduction of new features. TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024. Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). For more information see: https://github.com/tensorflow/addons/issues/2807 warnings.warn( INFO:absl:Function `__call__` contains input name(s) x, y with unsupported characters which will be renamed to transpose_161_x, add_52_y in the SavedModel. INFO:absl:Found untraced functions such as gen_tensor_dict while saving (showing 1 of 1). These functions will not be directly callable after loading. INFO:absl:Writing fingerprint to resnet50_garbage_classification/fingerprint.pb
pip install tensorflow
Requirement already satisfied: tensorflow in /usr/local/lib/python3.10/dist-packages (2.15.0) Requirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.4.0) Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.6.3) Requirement already satisfied: flatbuffers>=23.5.26 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (24.3.25) Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.6.0) Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.2.0) Requirement already satisfied: h5py>=2.9.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.11.0) Requirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (18.1.1) Requirement already satisfied: ml-dtypes~=0.2.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.2.0) Requirement already satisfied: numpy<2.0.0,>=1.23.5 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.25.2) Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.3.0) Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from tensorflow) (24.1) Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.20.3) Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from tensorflow) (71.0.4) Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.16.0) Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.4.0) Requirement already satisfied: typing-extensions>=3.6.6 in 
/usr/local/lib/python3.10/dist-packages (from tensorflow) (4.12.2) Requirement already satisfied: wrapt<1.15,>=1.11.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.14.1) Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.37.1) Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.64.1) Requirement already satisfied: tensorboard<2.16,>=2.15 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.15.2) Requirement already satisfied: tensorflow-estimator<2.16,>=2.15.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.15.0) Requirement already satisfied: keras<2.16,>=2.15.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.15.0) Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from astunparse>=1.6.0->tensorflow) (0.43.0) Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.16,>=2.15->tensorflow) (2.27.0) Requirement already satisfied: google-auth-oauthlib<2,>=0.5 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.16,>=2.15->tensorflow) (1.2.1) Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.16,>=2.15->tensorflow) (3.6) Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.16,>=2.15->tensorflow) (2.31.0) Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.16,>=2.15->tensorflow) (0.7.2) Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.16,>=2.15->tensorflow) (3.0.3) Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from 
google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow) (5.4.0) Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow) (0.4.0) Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow) (4.9) Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from google-auth-oauthlib<2,>=0.5->tensorboard<2.16,>=2.15->tensorflow) (1.3.1) Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow) (3.3.2) Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow) (3.7) Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow) (2.0.7) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard<2.16,>=2.15->tensorflow) (2024.7.4) Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard<2.16,>=2.15->tensorflow) (2.1.5) Requirement already satisfied: pyasn1<0.7.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.16,>=2.15->tensorflow) (0.6.0) Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<2,>=0.5->tensorboard<2.16,>=2.15->tensorflow) (3.2.2)
import tensorflow as tf
# Convert the SavedModel produced by onnx-tf into TensorFlow Lite format.
converter = tf.lite.TFLiteConverter.from_saved_model("resnet50_garbage_classification")
tflite_model = converter.convert()
# Persist the TFLite flatbuffer to disk.
with open("resnet50_garbage_classification.tflite", "wb") as f:
    f.write(tflite_model)
! jupyter nbconvert --to html garbageclassification.ipynb
[NbConvertApp] Converting notebook garbageclassification.ipynb to html
[NbConvertApp] ERROR | Notebook JSON is invalid: Additional properties are not allowed ('metadata' was unexpected)
Failed validating 'additionalProperties' in stream:
On instance['cells'][6]['outputs'][0]:
{'metadata': {'tags': None},
'name': 'stderr',
'output_type': 'stream',
'text': '/usr/local/lib/python3.10/dist-packages/torchvision/models/_util...'}
/usr/local/lib/python3.10/dist-packages/nbconvert/filters/widgetsdatatypefilter.py:71: UserWarning: Your element with mimetype(s) dict_keys(['application/vnd.colab-display-data+json']) is not able to be represented.
warn(
[NbConvertApp] Writing 10713332 bytes to garbageclassification.html